import copy

import networkx as nx
import matplotlib.pyplot as plt

from ModularUtils.ControllerConstants import generate_permutations
from ModularUtils.FunctionsConstant import getdoKey


def _attach_confounders(observed_dag, conf_to_child, num_confounders):
    """Return ``(complete_dag, latent_conf)`` for an observed DAG plus latent confounders.

    ``complete_dag`` starts with the parentless confounder nodes ``"U0"`` ...
    ``"U<num_confounders-1>"``, followed by every node of ``observed_dag`` in
    its original order.  Each observed node's parent list is its confounders
    (in ``conf_to_child`` order) followed by its observed parents.
    ``latent_conf`` maps every observed node to the confounders pointing into
    it.  Every child listed in ``conf_to_child`` must be a key of
    ``observed_dag`` (a ``KeyError`` is raised otherwise).
    """
    complete_dag = {"U" + str(idx): [] for idx in range(num_confounders)}
    latent_conf = {var: [] for var in observed_dag}
    for conf, children in conf_to_child.items():
        for child in children:
            latent_conf[child].append(conf)
    for var, parents in observed_dag.items():
        # Concatenation builds a fresh list, so complete_dag, latent_conf and
        # observed_dag never share list objects.
        complete_dag[var] = latent_conf[var] + parents
    return complete_dag, latent_conf


def _expand_interventions(intervention_list, label_dim):
    """Expand intervention specs into query dicts covering every intervened assignment.

    Each spec is a dict with ``"expr"``, ``"obs"`` and ``"inter_vars"`` keys.
    ``generate_permutations`` is called with the ``label_dim`` entries of the
    spec's ``inter_vars``; every combination it yields is zipped with those
    variable names into a ``{var: value}`` dict.  Each returned query has the
    keys ``"obs"`` (the spec's own list, not a copy), ``"intervs"`` and
    ``"expr"``.
    """
    queries = []
    for intervention in intervention_list:
        inter_vars = intervention["inter_vars"]
        perms = generate_permutations([label_dim[var] for var in inter_vars])
        queries.append({
            "obs": intervention["obs"],
            "intervs": [dict(zip(inter_vars, comb)) for comb in perms],
            "expr": intervention["expr"],
        })
    return queries


def set_asia_graph(noise_states, latent_state, obs_state, Data_intervs):
    """Build the configuration for the ASIA / lung-cancer causal graph.

    Observed structure: asia -> tub -> either <- lung, either -> xray and
    either -> dysp.  One latent confounder, ``U0``, points into ``lung`` and
    ``dysp``.

    Args:
        noise_states: ``label_dim`` entry of every exogenous noise variable
            ``"n" + label`` and the second element of its ``noise_params``
            tuple.
        latent_state: ``label_dim`` entry of each latent confounder and the
            second element of its ``noise_params`` tuple.
        obs_state: Ignored -- every observed variable gets ``label_dim`` 2.
        Data_intervs: Unused by this function.

    Returns:
        A 24-tuple ``(DAG_desc, Complete_DAG_desc, Complete_DAG,
        complete_labels, Observed_DAG, label_names, image_labels, rep_labels,
        interv_queries, cf_queries, latent_conf, confTochild, exogenous,
        cf_intervene, cf_observe, cf_evidence, cflabel_names, twin_map,
        Twin_Network, cf_exogenous, noise_params, train_mech_dict, label_dim,
        plot_title)``.  ``image_labels``, ``rep_labels``, ``cf_queries`` and
        all counterfactual structures are empty for this graph.

    Side effects:
        Prints ``"printing"`` followed by one line per observed variable with
        its ``train_mech_dict`` entry.
    """
    DAG_desc = "asia_graph"
    Complete_DAG_desc = "asia_graph"
    plot_title = "ASIA/LUNG CANCER distribution convergence"

    Observed_DAG = {
        "asia": [],
        "lung": [],
        "tub": ["asia"],
        "either": ["tub", "lung"],
        "xray": ["either"],
        "dysp": ["either"],
    }
    # NOTE(review): the caller's obs_state is overwritten here, presumably
    # because every ASIA variable is binary -- confirm before honouring the
    # parameter (doing so would change label_dim for callers passing != 2).
    obs_state = 2

    num_confounders = 1
    confTochild = {"U0": ["lung", "dysp"]}
    Complete_DAG, latent_conf = _attach_confounders(
        Observed_DAG, confTochild, num_confounders
    )

    complete_labels = list(Complete_DAG)
    label_names = list(Observed_DAG)

    # This graph has no image or representation variables.
    image_labels = []
    rep_labels = []

    # Insertion order matters to callers iterating label_dim:
    # observed variables, then confounders, then noise variables.
    label_dim = {label: obs_state for label in Observed_DAG}
    label_dim.update({conf: latent_state for conf in confTochild})
    label_dim.update(
        {"n" + label: noise_states for label in Observed_DAG if label not in image_labels}
    )

    # (observed variables, intervened variables).  The query key is produced by
    # getdoKey; the trailing comment is only the human-readable form.
    intervention_specs = [
        (["asia", "tub", "either", "xray"], []),                  # P(asia,tub,either,xray)
        (["either", "lung", "dysp"], []),                          # P(either,lung,dysp)
        (["asia", "tub", "lung", "either", "xray", "dysp"], []),  # P(V)
        (["dysp"], ["lung"]),                                      # P(dysp|do(lung))
        (["dysp"], ["either"]),                                    # P(dysp|do(either))
    ]
    intervention_list = [
        {"expr": getdoKey(obs, inter_vars), "obs": obs, "inter_vars": inter_vars}
        for obs, inter_vars in intervention_specs
    ]
    interv_queries = _expand_interventions(intervention_list, label_dim)

    cf_queries = []

    exogenous = {label: "n" + label for label in label_names if label not in image_labels}

    # Counterfactual structures are left empty for this graph.
    cflabel_names = []
    Twin_Network = {}
    cf_exogenous = {}
    cf_intervene = {}
    cf_observe = []
    cf_evidence = {}
    twin_map = {}

    # Noise variables first, then confounders.
    noise_params = {"n" + label: (0.5, noise_states) for label in Observed_DAG}
    noise_params.update({conf: (0.1, latent_state) for conf in confTochild})

    # compare: the variables whose joint is needed;
    # parents: the variables that must be intervened on.
    train_mech_dict = {
        "asia": [{'parents': [], 'intv': {}, 'compare': ['asia']}],
        "lung": [{'parents': [], 'intv': {}, 'compare': ['lung']}],
        "tub": [{'parents': [], 'intv': {}, 'compare': ['asia', 'tub']}],
        "either": [{'parents': ['lung'], 'intv': {}, 'compare': ['asia', 'tub', 'either']}],
        "xray": [{'parents': [], 'intv': {}, 'compare': ['asia', 'tub', 'either', 'xray']}],
        "dysp": [{'parents': [], 'intv': {}, 'compare': ['either', 'lung', 'dysp']}],
    }

    # Diagnostic dump of the per-variable training mechanisms.
    print("printing")
    for label in label_names:
        print(label, train_mech_dict[label])

    return (
        DAG_desc, Complete_DAG_desc, Complete_DAG, complete_labels, Observed_DAG,
        label_names, image_labels, rep_labels, interv_queries, cf_queries, latent_conf,
        confTochild, exogenous, cf_intervene, cf_observe, cf_evidence, cflabel_names,
        twin_map, Twin_Network, cf_exogenous,
        noise_params, train_mech_dict, label_dim, plot_title,
    )



